{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "N6ZDpd9XzFeN"
      },
      "source": [
        "##### Copyright 2018 The TensorFlow Hub Authors.\n",
        "\n",
        "Licensed under the Apache License, Version 2.0 (the \"License\");"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "cellView": "both",
        "colab": {},
        "colab_type": "code",
        "id": "KUu4vOt5zI9d"
      },
      "outputs": [],
      "source": [
        "# Copyright 2018 The TensorFlow Hub Authors. All Rights Reserved.\n",
        "#\n",
        "# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "#     http://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License.\n",
        "# =============================================================================="
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "CxmDMK4yupqg"
      },
      "source": [
        "# Object Detection\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "MfBg1C5NB3X0"
      },
      "source": [
        "\u003ctable class=\"tfo-notebook-buttons\" align=\"left\"\u003e\n",
        "  \u003ctd\u003e\n",
        "    \u003ca target=\"_blank\" href=\"https://www.tensorflow.org/hub/tutorials/object_detection\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/tf_logo_32px.png\" /\u003eView on TensorFlow.org\u003c/a\u003e\n",
        "  \u003c/td\u003e\n",
        "  \u003ctd\u003e\n",
        "    \u003ca target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/hub/blob/master/examples/colab/object_detection.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" /\u003eRun in Google Colab\u003c/a\u003e\n",
        "  \u003c/td\u003e\n",
        "  \u003ctd\u003e\n",
        "    \u003ca target=\"_blank\" href=\"https://github.com/tensorflow/hub/blob/master/examples/colab/object_detection.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /\u003eView source on GitHub\u003c/a\u003e\n",
        "  \u003c/td\u003e\n",
        "  \u003ctd\u003e\n",
        "    \u003ca href=\"https://storage.googleapis.com/tensorflow_docs/hub/examples/colab/object_detection.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/download_logo_32px.png\" /\u003eDownload notebook\u003c/a\u003e\n",
        "  \u003c/td\u003e\n",
        "\u003c/table\u003e"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "Sy553YSVmYiK"
      },
      "source": [
        "This Colab demonstrates use of a TF-Hub module trained to perform object detection."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "v4XGxDrCkeip"
      },
      "source": [
        "## Setup\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "cellView": "both",
        "colab": {},
        "colab_type": "code",
        "id": "6cPY9Ou4sWs_"
      },
      "outputs": [],
      "source": [
        "#@title Imports and function definitions\n",
        "\n",
        "# For running inference on the TF-Hub module.\n",
        "import tensorflow as tf\n",
        "\n",
        "import tensorflow_hub as hub\n",
        "\n",
        "# For downloading the image.\n",
        "import matplotlib.pyplot as plt\n",
        "import tempfile\n",
        "from six.moves.urllib.request import urlopen\n",
        "from six import BytesIO\n",
        "\n",
        "# For drawing onto the image.\n",
        "import numpy as np\n",
        "from PIL import Image\n",
        "from PIL import ImageColor\n",
        "from PIL import ImageDraw\n",
        "from PIL import ImageFont\n",
        "from PIL import ImageOps\n",
        "\n",
        "# For measuring the inference time.\n",
        "import time\n",
        "\n",
        "# Print Tensorflow version\n",
        "print(tf.__version__)\n",
        "\n",
        "# Check available GPU devices.\n",
        "print(\"The following GPU devices are available: %s\" % tf.test.gpu_device_name())"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "ZGkrXGy62409"
      },
      "source": [
        "## Example use"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "vlA3CftFpRiW"
      },
      "source": [
        "### Helper functions for downloading images and for visualization.\n",
        "\n",
        "Visualization code adapted from [TF object detection API](https://github.com/tensorflow/models/blob/master/research/object_detection/utils/visualization_utils.py) for the simplest required functionality."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "D9IwDpOtpIHW"
      },
      "outputs": [],
      "source": [
        "def display_image(image):\n",
        "  fig = plt.figure(figsize=(20, 15))\n",
        "  plt.grid(False)\n",
        "  plt.imshow(image)\n",
        "\n",
        "\n",
        "def download_and_resize_image(url, new_width=256, new_height=256,\n",
        "                              display=False):\n",
        "  _, filename = tempfile.mkstemp(suffix=\".jpg\")\n",
        "  response = urlopen(url)\n",
        "  image_data = response.read()\n",
        "  image_data = BytesIO(image_data)\n",
        "  pil_image = Image.open(image_data)\n",
        "  pil_image = ImageOps.fit(pil_image, (new_width, new_height), Image.ANTIALIAS)\n",
        "  pil_image_rgb = pil_image.convert(\"RGB\")\n",
        "  pil_image_rgb.save(filename, format=\"JPEG\", quality=90)\n",
        "  print(\"Image downloaded to %s.\" % filename)\n",
        "  if display:\n",
        "    display_image(pil_image)\n",
        "  return filename\n",
        "\n",
        "\n",
        "def draw_bounding_box_on_image(image,\n",
        "                               ymin,\n",
        "                               xmin,\n",
        "                               ymax,\n",
        "                               xmax,\n",
        "                               color,\n",
        "                               font,\n",
        "                               thickness=4,\n",
        "                               display_str_list=()):\n",
        "  \"\"\"Adds a bounding box to an image.\"\"\"\n",
        "  draw = ImageDraw.Draw(image)\n",
        "  im_width, im_height = image.size\n",
        "  (left, right, top, bottom) = (xmin * im_width, xmax * im_width,\n",
        "                                ymin * im_height, ymax * im_height)\n",
        "  draw.line([(left, top), (left, bottom), (right, bottom), (right, top),\n",
        "             (left, top)],\n",
        "            width=thickness,\n",
        "            fill=color)\n",
        "\n",
        "  # If the total height of the display strings added to the top of the bounding\n",
        "  # box exceeds the top of the image, stack the strings below the bounding box\n",
        "  # instead of above.\n",
        "  display_str_heights = [font.getsize(ds)[1] for ds in display_str_list]\n",
        "  # Each display_str has a top and bottom margin of 0.05x.\n",
        "  total_display_str_height = (1 + 2 * 0.05) * sum(display_str_heights)\n",
        "\n",
        "  if top \u003e total_display_str_height:\n",
        "    text_bottom = top\n",
        "  else:\n",
        "    text_bottom = top + total_display_str_height\n",
        "  # Reverse list and print from bottom to top.\n",
        "  for display_str in display_str_list[::-1]:\n",
        "    text_width, text_height = font.getsize(display_str)\n",
        "    margin = np.ceil(0.05 * text_height)\n",
        "    draw.rectangle([(left, text_bottom - text_height - 2 * margin),\n",
        "                    (left + text_width, text_bottom)],\n",
        "                   fill=color)\n",
        "    draw.text((left + margin, text_bottom - text_height - margin),\n",
        "              display_str,\n",
        "              fill=\"black\",\n",
        "              font=font)\n",
        "    text_bottom -= text_height - 2 * margin\n",
        "\n",
        "\n",
        "def draw_boxes(image, boxes, class_names, scores, max_boxes=10, min_score=0.1):\n",
        "  \"\"\"Overlay labeled boxes on an image with formatted scores and label names.\"\"\"\n",
        "  colors = list(ImageColor.colormap.values())\n",
        "\n",
        "  try:\n",
        "    font = ImageFont.truetype(\"/usr/share/fonts/truetype/liberation/LiberationSansNarrow-Regular.ttf\",\n",
        "                              25)\n",
        "  except IOError:\n",
        "    print(\"Font not found, using default font.\")\n",
        "    font = ImageFont.load_default()\n",
        "\n",
        "  for i in range(min(boxes.shape[0], max_boxes)):\n",
        "    if scores[i] \u003e= min_score:\n",
        "      ymin, xmin, ymax, xmax = tuple(boxes[i])\n",
        "      display_str = \"{}: {}%\".format(class_names[i].decode(\"ascii\"),\n",
        "                                     int(100 * scores[i]))\n",
        "      color = colors[hash(class_names[i]) % len(colors)]\n",
        "      image_pil = Image.fromarray(np.uint8(image)).convert(\"RGB\")\n",
        "      draw_bounding_box_on_image(\n",
        "          image_pil,\n",
        "          ymin,\n",
        "          xmin,\n",
        "          ymax,\n",
        "          xmax,\n",
        "          color,\n",
        "          font,\n",
        "          display_str_list=[display_str])\n",
        "      np.copyto(image, np.array(image_pil))\n",
        "  return image"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "D19UCu9Q2-_8"
      },
      "source": [
        "## Apply module\n",
        "\n",
        "Load a public image from Open Images v4, save locally, and display."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "cellView": "both",
        "colab": {},
        "colab_type": "code",
        "id": "YLWNhjUY1mhg"
      },
      "outputs": [],
      "source": [
        "# By Heiko Gorski, Source: https://commons.wikimedia.org/wiki/File:Naxos_Taverna.jpg\n",
        "image_url = \"https://upload.wikimedia.org/wikipedia/commons/6/60/Naxos_Taverna.jpg\"  #@param\n",
        "downloaded_image_path = download_and_resize_image(image_url, 1280, 856, True)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "t-VdfLbC1w51"
      },
      "source": [
        "Pick an object detection module and apply on the downloaded image. Modules:\n",
        "* **FasterRCNN+InceptionResNet V2**: high accuracy,\n",
        "* **ssd+mobilenet V2**: small and fast."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "uazJ5ASc2_QE"
      },
      "outputs": [],
      "source": [
        "module_handle = \"https://tfhub.dev/google/faster_rcnn/openimages_v4/inception_resnet_v2/1\" #@param [\"https://tfhub.dev/google/openimages_v4/ssd/mobilenet_v2/1\", \"https://tfhub.dev/google/faster_rcnn/openimages_v4/inception_resnet_v2/1\"]\n",
        "\n",
        "detector = hub.load(module_handle).signatures['default']"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "znW8Fq1EC0x7"
      },
      "outputs": [],
      "source": [
        "def load_img(path):\n",
        "  img = tf.io.read_file(path)\n",
        "  img = tf.image.decode_jpeg(img, channels=3)\n",
        "  return img"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "kwGJV96WWBLH"
      },
      "outputs": [],
      "source": [
        "def run_detector(detector, path):\n",
        "  img = load_img(path)\n",
        "\n",
        "  converted_img  = tf.image.convert_image_dtype(img, tf.float32)[tf.newaxis, ...]\n",
        "  start_time = time.time()\n",
        "  result = detector(converted_img)\n",
        "  end_time = time.time()\n",
        "\n",
        "  result = {key:value.numpy() for key,value in result.items()}\n",
        "\n",
        "  print(\"Found %d objects.\" % len(result[\"detection_scores\"]))\n",
        "  print(\"Inference time: \", end_time-start_time)\n",
        "\n",
        "  image_with_boxes = draw_boxes(\n",
        "      img.numpy(), result[\"detection_boxes\"],\n",
        "      result[\"detection_class_entities\"], result[\"detection_scores\"])\n",
        "\n",
        "  display_image(image_with_boxes)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "vchaUW1XDodD"
      },
      "outputs": [],
      "source": [
        "run_detector(detector, downloaded_image_path)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "WUUY3nfRX7VF"
      },
      "source": [
        "### More images\n",
        "Perform inference on some additional images with time tracking.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "rubdr2JXfsa1"
      },
      "outputs": [],
      "source": [
        "image_urls = [\n",
        "  # Source: https://commons.wikimedia.org/wiki/File:The_Coleoptera_of_the_British_islands_(Plate_125)_(8592917784).jpg\n",
        "  \"https://upload.wikimedia.org/wikipedia/commons/1/1b/The_Coleoptera_of_the_British_islands_%28Plate_125%29_%288592917784%29.jpg\",\n",
        "  # By Américo Toledano, Source: https://commons.wikimedia.org/wiki/File:Biblioteca_Maim%C3%B3nides,_Campus_Universitario_de_Rabanales_007.jpg\n",
        "  \"https://upload.wikimedia.org/wikipedia/commons/thumb/0/0d/Biblioteca_Maim%C3%B3nides%2C_Campus_Universitario_de_Rabanales_007.jpg/1024px-Biblioteca_Maim%C3%B3nides%2C_Campus_Universitario_de_Rabanales_007.jpg\",\n",
        "  # Source: https://commons.wikimedia.org/wiki/File:The_smaller_British_birds_(8053836633).jpg\n",
        "  \"https://upload.wikimedia.org/wikipedia/commons/0/09/The_smaller_British_birds_%288053836633%29.jpg\",\n",
        "  ]\n",
        "\n",
        "def detect_img(image_url):\n",
        "  start_time = time.time()\n",
        "  image_path = download_and_resize_image(image_url, 640, 480)\n",
        "  run_detector(detector, image_path)\n",
        "  end_time = time.time()\n",
        "  print(\"Inference time:\",end_time-start_time)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "otPnrxMKIrj5"
      },
      "outputs": [],
      "source": [
        "detect_img(image_urls[0])"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "H5F7DkD5NtOx"
      },
      "outputs": [],
      "source": [
        "detect_img(image_urls[1])"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "DZ18R7dWNyoU"
      },
      "outputs": [],
      "source": [
        "detect_img(image_urls[2])"
      ]
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "collapsed_sections": [],
      "name": "Object detection",
      "private_outputs": true,
      "provenance": [],
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
